;;; guile-openai --- An OpenAI API client for Guile
;;; Copyright © 2023 Andrew Whatson <whatson@tailcall.au>
;;;
;;; This file is part of guile-openai.
;;;
;;; guile-openai is free software: you can redistribute it and/or modify
;;; it under the terms of the GNU Affero General Public License as
;;; published by the Free Software Foundation, either version 3 of the
;;; License, or (at your option) any later version.
;;;
;;; guile-openai is distributed in the hope that it will be useful, but
;;; WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;; Affero General Public License for more details.
;;;
;;; You should have received a copy of the GNU Affero General Public
;;; License along with guile-openai.  If not, see
;;; <https://www.gnu.org/licenses/>.

(define-module (openai utils foreign)
  #:use-module (ice-9 match)
  #:use-module (ice-9 format)
  #:use-module (ice-9 vlist)
  #:use-module (srfi srfi-1)
  #:use-module (srfi srfi-9)
  #:use-module (srfi srfi-9 gnu)
  #:use-module ((system foreign) #:prefix ffi:)
  #:use-module ((system foreign) #:select (define-wrapped-pointer-type))
  #:use-module (system foreign-library)
  #:export (c-type?
            c-type-name
            c-type-size

            int8 uint8 int16 uint16 int32 uint32 int64 uint64
            float double complex-double complex-float
            int unsigned-int long unsigned-long short unsigned-short
            size_t ssize_t ptrdiff_t intptr_t uintptr_t
            void pointer cstring bool

            define-foreign-type
            define-foreign-arg-type
            define-foreign-return-type

            define-foreign-enum-type
            define-foreign-pointer-type

            define-foreign-library
            define-foreign-function
            define-foreign-functions))

;;; C type marshalling

;; Record describing how one C type is marshalled across the FFI boundary.
;; `%make-c-type' is the raw positional constructor; the rest of this file
;; builds instances through `make-c-type'.
(define-record-type <c-type>
  (%make-c-type name repr wrapper unwrapper)
  c-type?
  ;; Name shown by the record printer.
  (name c-type-name)
  ;; Low-level type descriptor handed to `ffi:sizeof' and to
  ;; `foreign-library-function'.
  (repr c-type-repr)
  ;; Procedure applied to a raw foreign result followed by the original
  ;; call arguments; #f when `make-c-type' is not given #:wrap-result.
  (wrapper c-type-wrapper)
  ;; Procedure applied to each Scheme argument before the foreign call;
  ;; #f when `make-c-type' is not given #:unwrap-args.
  (unwrapper c-type-unwrapper))

(define* (make-c-type name repr #:key wrap-result unwrap-args)
  "Return a <c-type> called NAME whose foreign representation is REPR.
WRAP-RESULT is stored as the type's wrapper and UNWRAP-ARGS as its
unwrapper; either one is #f when omitted."
  (let ((wrapper   wrap-result)
        (unwrapper unwrap-args))
    (%make-c-type name repr wrapper unwrapper)))

(define* (print-c-type type #:optional port)
  "Format TYPE to PORT as `#<c-type NAME BASE>', where NAME is TYPE's own
name and BASE is the name of the base type registered for TYPE's
representation."
  (let* ((own-name  (c-type-name type))
         (base-type (get-base-type (c-type-repr type)))
         (base-name (c-type-name base-type)))
    (format port "#<c-type ~a ~a>" own-name base-name)))

(define (c-type-size type)
  "Return the result of `ffi:sizeof' applied to TYPE's foreign
representation."
  (let ((repr (c-type-repr type)))
    (ffi:sizeof repr)))

;; Install `print-c-type' as the printer for <c-type> records.
(set-record-type-printer! <c-type> print-c-type)

;; (define-foreign-type TYPE-NAME BASE ARGS ...)
;;
;; Bind TYPE-NAME to a new c-type that is named after the identifier
;; itself and reuses the representation of the c-type BASE.  ARGS are
;; passed through unchanged to `make-c-type' (its #:wrap-result and
;; #:unwrap-args keywords).
(define-syntax define-foreign-type
  (syntax-rules ()
    ((_ type-name base args ...)
     (define type-name
       (let ((label (symbol->string 'type-name))
             (repr  (c-type-repr base)))
         (make-c-type label repr args ...))))))

;; (define-foreign-arg-type TYPE-NAME BASE UNWRAPPER)
;;
;; Shorthand for `define-foreign-type' that supplies only #:unwrap-args,
;; so the resulting type has no result wrapper.
(define-syntax define-foreign-arg-type
  (syntax-rules ()
    ((_ type-name base unwrapper)
     (define-foreign-type type-name base
       #:unwrap-args unwrapper))))

;; (define-foreign-return-type TYPE-NAME BASE WRAPPER)
;;
;; Shorthand for `define-foreign-type' that supplies only #:wrap-result,
;; so the resulting type has no argument unwrapper.
(define-syntax define-foreign-return-type
  (syntax-rules ()
    ((_ type-name base wrapper)
     (define-foreign-type type-name base
       #:wrap-result wrapper))))

;;; Base types

;; Registry of base types: a vhash, initially empty, that
;; `register-base-type!' extends with entries keyed by foreign
;; representation (compared with `eqv?') and mapping to a c-type.
(define %base-types vlist-null)

(define (register-base-type! type)
  "Record TYPE as the base type for its representation.  Nothing is
changed when a base type is already registered for that representation,
so the first registration wins."
  (let ((repr (c-type-repr type)))
    (when (not (has-base-type? repr))
      (set! %base-types
            (vhash-consv repr type %base-types)))))

(define (has-base-type? repr)
  "Return #t if a base type is registered for REPR, #f otherwise."
  (if (vhash-assv repr %base-types)
      #t
      #f))

(define (get-base-type repr)
  "Return the base type registered for REPR.  The lookup result is
destructured with `match', which has no clause for a failed lookup, so an
unregistered REPR ends in a match failure."
  (let ((entry (vhash-assv repr %base-types)))
    (match entry
      ((_ . base) base))))

;; (define-base-type TYPE-NAME REPR)
;;
;; Bind TYPE-NAME to a c-type named after the identifier whose
;; representation is REPR and whose wrapper and unwrapper pass values
;; through untouched, then offer it to the base-type registry.
(define-syntax define-base-type
  (syntax-rules ()
    ((_ type-name repr)
     (begin
       (define type-name
         (make-c-type (symbol->string 'type-name) repr
                      #:unwrap-args (lambda (value) value)
                      #:wrap-result (lambda (raw . ignored) raw)))
       (register-base-type! type-name)))))

;; One pass-through c-type per primitive exposed by (system foreign), plus
;; `pointer' for the '* descriptor.  Registration is first-come: a later
;; definition whose representation is `eqv?' to an earlier one is still
;; defined, but is not recorded as the base type for that representation.
;; NOTE(review): presumably the platform-sized names (int, long, size_t, ...)
;; share a representation with one of the fixed-width types listed before
;; them, which is why the fixed-width types come first -- confirm.
(define-base-type int8           ffi:int8)
(define-base-type uint8          ffi:uint8)
(define-base-type int16          ffi:int16)
(define-base-type uint16         ffi:uint16)
(define-base-type int32          ffi:int32)
(define-base-type uint32         ffi:uint32)
(define-base-type int64          ffi:int64)
(define-base-type uint64         ffi:uint64)
(define-base-type float          ffi:float)
(define-base-type double         ffi:double)
(define-base-type complex-double ffi:complex-double)
(define-base-type complex-float  ffi:complex-float)
(define-base-type int            ffi:int)
(define-base-type unsigned-int   ffi:unsigned-int)
(define-base-type long           ffi:long)
(define-base-type unsigned-long  ffi:unsigned-long)
(define-base-type short          ffi:short)
(define-base-type unsigned-short ffi:unsigned-short)
(define-base-type size_t         ffi:size_t)
(define-base-type ssize_t        ffi:ssize_t)
(define-base-type ptrdiff_t      ffi:ptrdiff_t)
(define-base-type intptr_t       ffi:intptr_t)
(define-base-type uintptr_t      ffi:uintptr_t)
(define-base-type void           ffi:void)
(define-base-type pointer        '*)

;;; Common types

;; C string (`char *') marshalled to and from Scheme strings.
;;
;; A NULL pointer result is returned as #f instead of being handed to
;; `ffi:pointer->string', which signals an error on a null pointer.
;; Symmetrically, #f given as an argument is passed to C as a NULL
;; pointer; any other argument goes through `ffi:string->pointer' as
;; before.
(define-foreign-type cstring pointer
  #:wrap-result (lambda (ptr . _)
                  (if (ffi:null-pointer? ptr)
                      #f
                      (ffi:pointer->string ptr)))
  #:unwrap-args (lambda (str)
                  (if str
                      (ffi:string->pointer str)
                      ffi:%null-pointer)))

;; Boolean carried in a C `int': a zero result becomes #f and any other
;; result #t; arguments are sent as 0 for #f and 1 for everything else.
(define-foreign-type bool int
  #:unwrap-args (lambda (flag) (if flag 1 0))
  #:wrap-result (lambda (raw . _) (if (zero? raw) #f #t)))

;;; Enum types

;; (define-foreign-enum-type ENUM-NAME ENUM-BASE
;;   ENUMERATOR? ENUMERATOR-LIST INT->ENUMERATOR ENUMERATOR->INT
;;   (ENUMERATOR ...))
;;
;; Each ENUMERATOR entry is either a bare symbol or `SYMBOL => VALUE'.
;; Defines:
;;   ENUMERATOR?      predicate: #t when ENUMERATOR->INT yields a value.
;;   ENUMERATOR-LIST  thunk returning the enumerator symbols in order.
;;   ENUMERATOR->INT  symbol -> integer value, or #f when unknown.
;;   INT->ENUMERATOR  integer value -> symbol, or #f when unknown.
;;   ENUM-NAME        c-type over ENUM-BASE that converts with the two
;;                    procedures above.
(define-syntax define-foreign-enum-type
  (syntax-rules ()
    ((_ enum-name enum-base
        enumerator? enumerator-list
        int->enumerator enumerator->int
        (enumerator ...))
     (begin
       (define (enumerator? sym)
         (if (enumerator->int sym) #t #f))
       (define (enumerator-list)
         (%dfe-enum-symbols (enumerator ...)))
       (define enumerator->int
         ;; Symbols are compared by identity, hence the q-flavoured table.
         (let* ((names (%dfe-enum-symbols (enumerator ...)))
                (codes (%dfe-enum-values (enumerator ...)))
                (table (alist->vhash (map cons names codes) hashq)))
           (lambda (sym)
             (let ((entry (vhash-assq sym table)))
               (and entry (cdr entry))))))
       (define int->enumerator
         ;; Integers are compared with `eqv?', hence the v-flavoured table.
         (let* ((codes (%dfe-enum-values (enumerator ...)))
                (names (%dfe-enum-symbols (enumerator ...)))
                (table (alist->vhash (map cons codes names) hashv)))
           (lambda (code)
             (let ((entry (vhash-assv code table)))
               (and entry (cdr entry))))))
       (define-foreign-type enum-name enum-base
         #:unwrap-args enumerator->int
         #:wrap-result (lambda (raw . _) (int->enumerator raw)))))))

;; (%dfe-enum-symbols (SPEC ...))
;;
;; Expand to a quoted list of the enumerator names found in SPEC ..., in
;; order.  A name may be followed by `=> VALUE', which is skipped here.
;; The second operand of the recursive form is the accumulator.
(define-syntax %dfe-enum-symbols
  (syntax-rules (=>)
    ;; Entry point: start with an empty accumulator.
    ((_ (spec ...))
     (%dfe-enum-symbols (spec ...) ()))
    ;; Nothing left: emit the accumulated names as a literal list.
    ((_ () (done ...))
     '(done ...))
    ;; NAME => VALUE: keep the name, drop the explicit value.
    ((_ (name => value spec ...) (done ...))
     (%dfe-enum-symbols (spec ...) (done ... name)))
    ;; Bare NAME.
    ((_ (name spec ...) (done ...))
     (%dfe-enum-symbols (spec ...) (done ... name)))))

;; (%dfe-enum-values (SPEC ...))
;;
;; Expand to a `list' expression holding one value per enumerator in
;; SPEC ..., in order.  `NAME => VALUE' contributes VALUE; a bare NAME
;; contributes one more than the preceding value, with the first bare
;; name counting up from -1 (i.e. yielding 0).  The recursive form
;; carries the accumulated value expressions and the previous value
;; expression.
(define-syntax %dfe-enum-values
  (syntax-rules (=>)
    ;; Entry point: empty accumulator, "previous" value of -1.
    ((_ (spec ...))
     (%dfe-enum-values (spec ...) () -1))
    ;; Nothing left: emit the accumulated value expressions.
    ((_ () (done ...) last)
     (list done ...))
    ;; NAME => VALUE: VALUE is used as-is and becomes the new previous.
    ((_ (name => value spec ...) (done ...) last)
     (%dfe-enum-values (spec ...) (done ... value) value))
    ;; Bare NAME: successor of the previous value.
    ((_ (name spec ...) (done ...) last)
     (%dfe-enum-values (spec ...) (done ... (1+ last)) (1+ last)))))

;;; Pointer types

;; (define-foreign-pointer-type POINTER-NAME RECORD-TYPE
;;   RECORD? POINTER->RECORD RECORD->POINTER)
;;
;; Declare a wrapped pointer type via `define-wrapped-pointer-type' whose
;; instances print as `#<POINTER-NAME 0xADDRESS>', and bind POINTER-NAME
;; to a c-type over `pointer' that converts with POINTER->RECORD and
;; RECORD->POINTER.
(define-syntax define-foreign-pointer-type
  (syntax-rules ()
    ((_ pointer-name record-type
        record? pointer->record record->pointer)
     (begin
       (define-wrapped-pointer-type record-type
         record? pointer->record record->pointer
         (lambda (obj out)
           (let* ((raw  (record->pointer obj))
                  (addr (ffi:pointer-address raw)))
             (format out "#<~a 0x~x>" 'pointer-name addr))))
       (define-foreign-type pointer-name pointer
         #:unwrap-args record->pointer
         #:wrap-result (lambda (raw . _) (pointer->record raw)))))))

;;; Function wrappers

;; (define-foreign-library LIBRARY PATH ARGS ...)
;;
;; Bind LIBRARY to the result of `load-foreign-library' called with PATH
;; and ARGS.
(define-syntax define-foreign-library
  (syntax-rules ()
    ((_ library path args ...)
     (define library
       (load-foreign-library path args ...)))))

;; (define-foreign-function LIBRARY (FUNCTION-NAME ARG-TYPE ... -> RETURN-TYPE))
;;
;; Bind FUNCTION-NAME to a wrapped foreign procedure.  The symbol looked
;; up in LIBRARY is the identifier's own name; the signature is turned
;; into keyword arguments by `%dff-parse-signature'.
(define-syntax define-foreign-function
  (syntax-rules ()
    ((_ library (function-name signature ...))
     (define function-name
       (let ((c-name  (symbol->string 'function-name))
             (options (%dff-parse-signature (signature ...))))
         (apply wrapped-foreign-library-function library c-name options))))))

;; (%dff-parse-signature (ARG-TYPE ... -> RETURN-TYPE))
;;
;; Expand to `(list #:return-type RETURN-TYPE #:arg-types (list ARG-TYPE ...))'.
;; Operands after the parenthesised signature accumulate the argument
;; types seen so far.  The clauses are order-sensitive: the `->' clause
;; must be tried before the general one.
(define-syntax %dff-parse-signature
  (syntax-rules (->)
    ;; Only `-> RESULT' remains: emit the keyword list.
    ((_ (-> result) params ...)
     (list #:return-type result
           #:arg-types (list params ...)))
    ;; Otherwise move the leading element onto the end of the accumulator.
    ((_ (head tail ...) params ...)
     (%dff-parse-signature (tail ...) params ... head))))

;; (define-foreign-functions LIBRARY (FUNCTION-NAME SIGNATURE ...) ...)
;;
;; Expand to one `define-foreign-function' form per listed function, all
;; against the same LIBRARY.
(define-syntax define-foreign-functions
  (syntax-rules ()
    ((_ library (function-name signature ...) ...)
     (begin
       (define-foreign-function library
         (function-name signature ...))
       ...))))

(define* (wrapped-foreign-library-function library function-name
                                           #:key return-type arg-types)
  "Return a Scheme procedure that calls the C function FUNCTION-NAME from
LIBRARY, with marshalling driven by c-types.

RETURN-TYPE is the c-type of the result and ARG-TYPES the list of c-types
of the parameters.  On each call every argument is passed through the
unwrapper of its c-type, the foreign function is applied to the converted
arguments, and the raw result is passed to RETURN-TYPE's wrapper together
with the original (unconverted) arguments.

The returned procedure throws `wrong-number-of-args' unless it is called
with exactly as many arguments as there are ARG-TYPES."
  (let* ((wrap-result (c-type-wrapper return-type))
         (unwrappers  (map c-type-unwrapper arg-types))
         (arity       (length arg-types))
         (foreign-function
          (foreign-library-function library function-name
                                    #:return-type (c-type-repr return-type)
                                    #:arg-types (map c-type-repr arg-types))))
    (lambda args
      ;; The two-list `map' below stops at the shorter list, so without
      ;; this check surplus arguments would be dropped silently.  Check
      ;; the count up front, before any argument is converted.
      (let ((given (length args)))
        (unless (= given arity)
          (scm-error 'wrong-number-of-args function-name
                     "Wrong number of arguments to ~A: expected ~A, got ~A"
                     (list function-name arity given)
                     #f)))
      (let* ((raw-args   (map (lambda (unwrap arg)
                                (unwrap arg))
                              unwrappers args))
             (raw-result (apply foreign-function raw-args)))
        ;; The wrapper also receives the original arguments, so a result
        ;; conversion can depend on what the caller passed in.
        (apply wrap-result raw-result args)))))
